{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Road Following - Build TensorRT model for live demo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will use TensorRT to optimize the model we trained."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the trained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We assume that you have already downloaded ``best_steering_model_xy.pth`` to your workstation as instructed in the \"train_model.ipynb\" notebook. Now upload the model file to the JetBot, into this notebook's directory. Once the upload finishes, there should be a file named ``best_steering_model_xy.pth`` in this notebook's directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Please make sure the file has uploaded fully before running the next cell."
   ]
  },
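  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The check in the note above can be partly automated. The following is a minimal sketch, not part of the original JetBot workflow: the helper name `upload_is_complete` and its `min_bytes` and `wait_s` parameters are hypothetical, and a stable file size is only a heuristic that the transfer has stopped, not proof that the file is intact.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import time\n",
    "\n",
    "\n",
    "def upload_is_complete(path, min_bytes=1, wait_s=2.0):\n",
    "    # Hypothetical helper: returns True if the file exists, holds at least\n",
    "    # min_bytes bytes, and its size did not change over wait_s seconds.\n",
    "    if not os.path.isfile(path):\n",
    "        return False\n",
    "    size_before = os.path.getsize(path)\n",
    "    time.sleep(wait_s)  # give an in-progress upload time to grow the file\n",
    "    size_after = os.path.getsize(path)\n",
    "    return size_before == size_after and size_after >= min_bytes\n",
    "\n",
    "\n",
    "print(upload_is_complete('best_steering_model_xy.pth'))\n",
    "```\n",
    "\n",
    "If this prints `False`, wait for the upload to finish and run the check again. Comparing `os.path.getsize` on the JetBot with the file size on your workstation is a more reliable test."
   ]
  },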
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the code below to initialize the PyTorch model. This should look very familiar from the training notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torch\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=False)\n",
    "model.fc = torch.nn.Linear(512, 2)\n",
    "model = model.cuda().eval().half()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, load the trained weights from the ``best_steering_model_xy.pth`` file that you uploaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load('best_steering_model_xy.pth'))"
   ]
  },
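  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model was already moved to the GPU and converted to half precision when it was initialized above, and ``load_state_dict`` copies the loaded weights into those GPU parameters. The cell below only defines a handle to the CUDA device."
   ]
  },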
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorRT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> If your setup does not have `torch2trt` installed, install it first by running the following commands in a terminal.\n",
    "```bash\n",
    "cd $HOME\n",
    "git clone https://github.com/NVIDIA-AI-IOT/torch2trt\n",
    "cd torch2trt\n",
    "sudo python3 setup.py install\n",
    "```\n",
    "\n",
    "Convert and optimize the model using torch2trt for faster inference with TensorRT. Please see the [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) readme for more details.\n",
    "\n",
    "> This optimization process can take a couple of minutes to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch2trt import torch2trt\n",
    "\n",
    "data = torch.zeros((1, 3, 224, 224)).cuda().half()\n",
    "\n",
    "model_trt = torch2trt(model, [data], fp16_mode=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the optimized model using the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model_trt.state_dict(), 'best_steering_model_xy_trt.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "Open ``live_demo_trt.ipynb`` to move the JetBot with the TensorRT-optimized model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
